{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License.\n",
    "\n",
    "# MONAI 201 tutorial\n",
    "\n",
    "In this tutorial we'll revisit the [MONAI 101 notebook](https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/monai_101.ipynb) and add features that reflect best practices, including evaluation during training and TensorBoard logging handlers.\n",
    "\n",
    "The tutorial covers the following steps, each of which takes only a few lines of code:\n",
    "- Dataset download and data pre-processing\n",
    "- Define a DenseNet-121 and run training\n",
    "- Run inference using `SupervisedEvaluator`\n",
    "\n",
    "This tutorial uses about 7 GB of GPU memory and takes about 10 minutes to run.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/2d_classification/monai_201.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[ignite, tqdm, tensorboard]\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import numpy as np\n",
    "import os\n",
    "from pathlib import Path\n",
    "import sys\n",
    "import tempfile\n",
    "import torch\n",
    "import ignite\n",
    "\n",
    "from monai.apps import MedNISTDataset\n",
    "from monai.config import print_config\n",
    "from monai.data import DataLoader\n",
    "from monai.engines import SupervisedTrainer, SupervisedEvaluator\n",
    "from monai.handlers import (\n",
    "    StatsHandler,\n",
    "    TensorBoardStatsHandler,\n",
    "    ValidationHandler,\n",
    "    CheckpointSaver,\n",
    "    CheckpointLoader,\n",
    "    ClassificationSaver,\n",
    ")\n",
    "from monai.handlers.utils import from_engine\n",
    "from monai.inferers import SimpleInferer\n",
    "from monai.networks.nets import densenet121\n",
    "from monai.transforms import LoadImageD, EnsureChannelFirstD, ScaleIntensityD, Compose, AsDiscreted\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup data directory\n",
    "\n",
    "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  \n",
    "This allows you to save results and reuse downloads.  \n",
    "If not specified, a temporary directory will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/Data\n"
     ]
    }
   ],
   "source": [
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "if directory is not None:\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "print(root_dir)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use MONAI transforms to preprocess data\n",
    "\n",
    "We first prepare the data as in the MONAI 101 tutorial, using the same transforms and dataset. This time we also load the validation split so that the model can be evaluated during training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2024-02-27 08:31:31,955 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.\n",
      "2024-02-27 08:31:31,955 - INFO - File exists: /workspace/Data/MedNIST.tar.gz, skipped downloading.\n",
      "2024-02-27 08:31:31,956 - INFO - Non-empty folder exists in /workspace/Data/MedNIST, skipped extracting.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading dataset:   0%|          | 0/47164 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading dataset: 100%|██████████| 47164/47164 [00:19<00:00, 2393.21it/s]\n",
      "Loading dataset: 100%|██████████| 5895/5895 [00:02<00:00, 2465.05it/s]\n"
     ]
    }
   ],
   "source": [
    "transform = Compose(\n",
    "    [\n",
    "        LoadImageD(keys=\"image\", image_only=True),\n",
    "        EnsureChannelFirstD(keys=\"image\"),\n",
    "        ScaleIntensityD(keys=\"image\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# If you use the MedNIST dataset, please acknowledge the source.\n",
    "dataset = MedNISTDataset(root_dir=root_dir, transform=transform, section=\"training\", download=True)\n",
    "valdata = MedNISTDataset(root_dir=root_dir, transform=transform, section=\"validation\", download=False)"
   ]
  },
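  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the last transform more concrete: with its default arguments, `ScaleIntensityD` rescales each image to the range [0, 1]. The NumPy snippet below is a minimal sketch of that min-max scaling, not MONAI's implementation; the function name `min_max_scale` and the sample values are assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def min_max_scale(img, minv=0.0, maxv=1.0):\n",
    "    # Hypothetical helper: linearly map the values of img to [minv, maxv].\n",
    "    lo, hi = img.min(), img.max()\n",
    "    if hi == lo:\n",
    "        # A constant image has no range to stretch; map everything to minv.\n",
    "        return np.full_like(img, minv, dtype=np.float32)\n",
    "    norm = (img - lo) / (hi - lo)\n",
    "    return (norm * (maxv - minv) + minv).astype(np.float32)\n",
    "\n",
    "\n",
    "# Illustrative 8-bit pixel values, not taken from MedNIST.\n",
    "sample = np.array([[0, 64], [128, 255]], dtype=np.float32)\n",
    "scaled = min_max_scale(sample)\n",
    "print(scaled.min(), scaled.max())\n",
    "```\n",
    "\n",
    "Scaling every image to a common range keeps the network input comparable across the different imaging modalities in the dataset."
   ]
  },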
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a network and a supervised trainer\n",
    "\n",
    "Training uses the same elements as before, but we extend the `train_handlers` of the `SupervisedTrainer`: a `ValidationHandler` runs validation after every epoch, a `CheckpointSaver` writes model checkpoints, and a `TensorBoardStatsHandler` logs the training loss.\n",
    "We also introduce a `SupervisedEvaluator` that computes the accuracy on the validation set. Its own `StatsHandler` and `TensorBoardStatsHandler` log this metric at the end of each validation run, so both the training loss and the validation accuracy can be followed in TensorBoard."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_epochs = 5\n",
    "save_interval = 2\n",
    "out_dir = \"./eval\"\n",
    "model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(\"cuda:0\")\n",
    "\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
    "\n",
    "evaluator = SupervisedEvaluator(\n",
    "    device=torch.device(\"cuda:0\"),\n",
    "    val_data_loader=DataLoader(valdata, batch_size=512, shuffle=False, num_workers=4),\n",
    "    network=model,\n",
    "    inferer=SimpleInferer(),\n",
    "    key_val_metric={\"val_acc\": ignite.metrics.Accuracy(from_engine([\"pred\", \"label\"]))},\n",
    "    val_handlers=[StatsHandler(iteration_log=False), TensorBoardStatsHandler(iteration_log=False)],\n",
    ")\n",
    "\n",
    "trainer = SupervisedTrainer(\n",
    "    device=torch.device(\"cuda:0\"),\n",
    "    max_epochs=max_epochs,\n",
    "    train_data_loader=DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4),\n",
    "    network=model,\n",
    "    optimizer=torch.optim.Adam(model.parameters(), lr=1e-5),\n",
    "    loss_function=torch.nn.CrossEntropyLoss(),\n",
    "    inferer=SimpleInferer(),\n",
    "    train_handlers=[\n",
    "        ValidationHandler(validator=evaluator, epoch_level=True, interval=1),\n",
    "        CheckpointSaver(\n",
    "            save_dir=out_dir,\n",
    "            save_dict={\"model\": model},\n",
    "            save_interval=save_interval,\n",
    "            save_final=True,\n",
    "            final_filename=\"checkpoint.pt\",\n",
    "        ),\n",
    "        StatsHandler(),\n",
    "        TensorBoardStatsHandler(tag_name=\"train_loss\", output_transform=from_engine([\"loss\"], first=True)),\n",
    "    ],\n",
    ")"
   ]
  },
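  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the settings above, validation runs after every epoch (`interval=1`) and a checkpoint is written every `save_interval` epochs, plus a final one when training completes. The plain-Python sketch below only enumerates that schedule, under the assumption that both handlers fire at epoch completion; it does not call MONAI or Ignite, and `epoch_events` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "def epoch_events(max_epochs, val_interval, save_interval):\n",
    "    # Hypothetical helper: list what happens at the end of each epoch.\n",
    "    schedule = []\n",
    "    for epoch in range(1, max_epochs + 1):\n",
    "        events = []\n",
    "        if epoch % val_interval == 0:\n",
    "            events.append('validate')\n",
    "        if epoch % save_interval == 0:\n",
    "            events.append('checkpoint')\n",
    "        if epoch == max_epochs:\n",
    "            # Stands in for save_final=True, which saves once training completes.\n",
    "            events.append('final checkpoint')\n",
    "        schedule.append((epoch, events))\n",
    "    return schedule\n",
    "\n",
    "\n",
    "for epoch, events in epoch_events(max_epochs=5, val_interval=1, save_interval=2):\n",
    "    print(epoch, events)\n",
    "```\n",
    "\n",
    "The final checkpoint is the one saved as `checkpoint.pt`, which the inference section loads later."
   ]
  },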
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## View training in TensorBoard\n",
    "\n",
    "Run the following cell to load the TensorBoard results. By default, `TensorBoardStatsHandler` writes its logs to `./runs`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext tensorboard\n",
    "%tensorboard --logdir ./runs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference\n",
    "\n",
    "First, prepare the test dataset and collect the class names from the dataset's subdirectory names:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_dir = Path(root_dir, \"MedNIST\")\n",
    "class_names = sorted(x.name for x in dataset_dir.iterdir() if x.is_dir())\n",
    "testdata = MedNISTDataset(root_dir=root_dir, transform=transform, section=\"test\", download=False, runtime_cache=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we create a new `SupervisedEvaluator` for inference. It processes every image in the test dataset and writes the results to a CSV file. Two handlers in `val_handlers` do the work: `CheckpointLoader` restores the model weights from the checkpoint file (and raises an error if the file is missing), and `ClassificationSaver` saves the predicted classes."
   ]
  },
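  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The evaluator in the next cell uses `AsDiscreted(keys=\"pred\", argmax=True)` as postprocessing, which turns the six raw network outputs for an image into a single class index. A minimal NumPy sketch of that step follows; the logits are made-up values and `to_class_index` is a hypothetical name, so treat it as an illustration rather than MONAI's code.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def to_class_index(logits):\n",
    "    # Hypothetical helper: pick the position of the largest value along the\n",
    "    # first (class/channel) axis and keep that axis with length 1.\n",
    "    return np.expand_dims(np.argmax(logits, axis=0), axis=0)\n",
    "\n",
    "\n",
    "# Made-up outputs for one image and six classes.\n",
    "logits = np.array([0.1, 2.3, -1.0, 0.4, 0.0, 1.7])\n",
    "print(to_class_index(logits))\n",
    "```\n",
    "\n",
    "The resulting index is what ends up in the CSV file and can later be mapped back to a class name."
   ]
  },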
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:ignite.engine.engine.SupervisedEvaluator:Engine run resuming from iteration 0, epoch 0 until 1 epochs\n",
      "INFO:ignite.engine.engine.SupervisedEvaluator:Restored all variables from ./eval/checkpoint.pt\n",
      "INFO:ignite.engine.engine.SupervisedEvaluator:Epoch[1] Complete. Time taken: 00:01:24.338\n",
      "INFO:ignite.engine.engine.SupervisedEvaluator:Engine run complete. Time taken: 00:01:24.390\n"
     ]
    }
   ],
   "source": [
    "evaluator = SupervisedEvaluator(\n",
    "    device=torch.device(\"cuda:0\"),\n",
    "    val_data_loader=DataLoader(testdata, batch_size=1, num_workers=0),\n",
    "    network=model,\n",
    "    inferer=SimpleInferer(),\n",
    "    postprocessing=AsDiscreted(keys=\"pred\", argmax=True),\n",
    "    val_handlers=[\n",
    "        CheckpointLoader(load_path=f\"{out_dir}/checkpoint.pt\", load_dict={\"model\": model}),\n",
    "        ClassificationSaver(\n",
    "            batch_transform=lambda batch: batch[0][\"image\"].meta, output_transform=from_engine([\"pred\"])\n",
    "        ),\n",
    "    ],\n",
    ")\n",
    "\n",
    "evaluator.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, the inference results are stored in a file named `predictions.csv` in the current directory; the output location and filename can be changed through the `ClassificationSaver` arguments.\n",
    "The first column of the file holds the image filename and the second column the predicted class index. Using these indices to look up entries in our `class_names` list makes the output easier to read."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/Data/MedNIST/AbdomenCT/006070.jpeg AbdomenCT\n",
      "/workspace/Data/MedNIST/BreastMRI/006574.jpeg BreastMRI\n",
      "/workspace/Data/MedNIST/ChestCT/009858.jpeg ChestCT\n",
      "/workspace/Data/MedNIST/CXR/007398.jpeg CXR\n",
      "/workspace/Data/MedNIST/Hand/005663.jpeg Hand\n",
      "/workspace/Data/MedNIST/HeadCT/006896.jpeg HeadCT\n",
      "/workspace/Data/MedNIST/HeadCT/007179.jpeg HeadCT\n",
      "/workspace/Data/MedNIST/CXR/001190.jpeg CXR\n",
      "/workspace/Data/MedNIST/ChestCT/005138.jpeg ChestCT\n",
      "/workspace/Data/MedNIST/BreastMRI/000023.jpeg BreastMRI\n"
     ]
    }
   ],
   "source": [
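    "max_items_to_print = 10\n",
    "# each CSV row holds the image filename and the predicted class index\n",
    "predictions = np.loadtxt(\"./predictions.csv\", delimiter=\",\", dtype=str)\n",
    "for fn, idx in predictions[:max_items_to_print]:\n",
    "    # convert via float first, since the index may be stored as a float string\n",
    "    print(fn, class_names[int(float(idx))])"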
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
